Skip to content

[mlir][linalg] Migrate Detensorize pass to new dialect conversion driver #152912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Aug 10, 2025

The pass used to access erased operations and block arguments in the type converter. That is no longer supported in the new conversion driver.

Depends on #151865.

@llvmbot
Copy link
Member

llvmbot commented Aug 10, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

Changes

The pass used to access erased operations and block arguments in the type converter. That is no longer supported in the new conversion driver.


Full diff: https://github.com/llvm/llvm-project/pull/152912.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (+32-2)
  • (modified) mlir/test/Dialect/Linalg/detensorize_0d.mlir (+4-3)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 830905495e759..221f95a8d8f33 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -458,6 +458,22 @@ struct LinalgDetensorize
     }
   };
 
+  /// A listener that forwards notifyBlockErased and notifyOperationErased to
+  /// the given callbacks.
+  struct CallbackListener : public RewriterBase::Listener {
+    CallbackListener(std::function<void(Operation *op)> onOperationErased,
+                     std::function<void(Block *block)> onBlockErased)
+        : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
+
+    void notifyBlockErased(Block *block) override { onBlockErased(block); }
+    void notifyOperationErased(Operation *op) override {
+      onOperationErased(op);
+    }
+
+    std::function<void(Operation *op)> onOperationErased;
+    std::function<void(Block *block)> onBlockErased;
+  };
+
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     DetensorizeTypeConverter typeConverter;
@@ -551,8 +567,22 @@ struct LinalgDetensorize
     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                    shouldConvertBranchOperand);
 
-    if (failed(
-            applyFullConversion(getOperation(), target, std::move(patterns))))
+    ConversionConfig config;
+    auto onOperationErased = [&](Operation *op) {
+      opsToDetensor.erase(op);
+      detensorableBranchOps.erase(op);
+    };
+    auto onBlockErased = [&](Block *block) {
+      for (BlockArgument arg : block->getArguments()) {
+        blockArgsToDetensor.erase(arg);
+      }
+    };
+    CallbackListener listener(onOperationErased, onBlockErased);
+
+    config.listener = &listener;
+    config.allowPatternRollback = false;
+    if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
+                                   config)))
       signalPassFailure();
 
     RewritePatternSet canonPatterns(context);
diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
index 74931cb0830bc..5c29b04630cad 100644
--- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
@@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tenso
 }
 // CHECK-LABEL: func @detensor_op_sequence
 // CHECK-SAME:    (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
-// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG:     %[[arg1_val_1:.*]] = tensor.extract %[[arg1]]
 // CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
-// CHECK:         %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
-// CHECK:         %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
+// CHECK-DAG:     %[[arg1_val_2:.*]] = tensor.extract %[[arg1]]
+// CHECK:         %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]]
+// CHECK:         %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]]
 // CHECK:         %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]]
 // CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
 // CHECK:         return %[[new_tensor_res]]

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, some nit, but I'll let others review and approve.

// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
// CHECK-DAG: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]]
// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHECK-DAG can come in any order, but the op here specifies them in a particular order, and the two arg1 vals have the same pattern. This may randomly fail.


config.listener = &listener;
config.allowPatternRollback = false;
if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the callback doesn't return anything, it won't change the success/failure result, so this should be fine.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/prototype_one_shot branch from 09d8ee9 to 2422ce2 Compare August 13, 2025 13:25
@matthias-springer matthias-springer force-pushed the users/matthias-springer/prototype_one_shot branch from 2422ce2 to 5ade1f5 Compare August 13, 2025 14:44
Base automatically changed from users/matthias-springer/prototype_one_shot to main August 13, 2025 15:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants